Seeing is believing

Using FlashTorch 🔦 to shine a light on what neural nets "see"


by Misa Ogura

Hello, I'm Misa 👋


  • Originally from Tokyo, now based in London
  • Cancer Cell Biologist, turned Software Engineer
  • Currently at BBC R&D
  • Co-founder of Women Driven Development
  • Women in Data Science London Ambassador

Feature visualisation


Introducing FlashTorch 🔦


  • Open source feature visualisation toolkit for neural nets in PyTorch

  • Supports torchvision models

  • Available to install via pip!

      $ pip install flashtorch

Image processing & CNN 101


Kernel & convolution


Kernel: a small matrix used for edge detection, blurring, sharpening, embossing, etc.

Convolution: an operation that computes a weighted sum of neighbouring pixels

Examples of convolution: detecting edges
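
Edge detection via convolution can be sketched in a few lines of NumPy (a hypothetical illustration, not FlashTorch code). Sliding a Sobel kernel over an image with a vertical edge produces strong responses where the intensity changes and zeros in flat regions:

```python
import numpy as np

def convolve2d(image, kernel):
    """Valid-mode 2D convolution (no kernel flip, as in CNNs):
    each output pixel is a weighted sum of its neighbourhood."""
    kh, kw = kernel.shape
    h, w = image.shape
    out = np.zeros((h - kh + 1, w - kw + 1))
    for i in range(out.shape[0]):
        for j in range(out.shape[1]):
            out[i, j] = np.sum(image[i:i + kh, j:j + kw] * kernel)
    return out

# A toy image with a vertical edge: dark left half, bright right half
image = np.zeros((5, 5))
image[:, 2:] = 1.0

# Sobel kernel for detecting vertical edges
sobel_x = np.array([[-1, 0, 1],
                    [-2, 0, 2],
                    [-1, 0, 1]], dtype=float)

edges = convolve2d(image, sobel_x)
print(edges)  # strong responses along the edge, zeros in flat regions
```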


Typical CNN architecture


Kernel weights are learnt during training to extract relevant features from input images.

Feature visualisation technique

Saliency maps


Saliency


  • A subjective quality in human visual perception

  • Makes certain items stand out and grabs our attention

Saliency maps in computer vision: indications of the most “salient” regions

Saliency maps in CNNs


  • First introduced in 2013

  • Gradients of target class w.r.t. input image via backpropagation

  • Pixels with positive gradients: some intuition of attention

  • Available via the flashtorch.saliency API
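
The core idea can be sketched without a CNN: for a toy linear "classifier" (a hypothetical NumPy illustration, not the FlashTorch implementation), the gradient of a class score with respect to the input tells us how strongly each pixel influences that score:

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy flattened "image" and a linear scoring layer: score_c = w_c . x
x = rng.random(16)                # 16 "pixels"
w = rng.normal(size=(3, 16))      # weights for 3 classes
target_class = 1

scores = w @ x

# For this linear model, the gradient of the target score w.r.t. the
# input is simply the weight row: d(score_c)/dx = w_c.
# A saliency map ranks pixels by this gradient; positive values mark
# pixels that push the target score up.
saliency = w[target_class]

# Sanity check: a finite-difference estimate agrees with the gradient
eps = 1e-6
bumped = x.copy()
bumped[0] += eps
fd = ((w @ bumped)[target_class] - scores[target_class]) / eps

print('most salient pixels:', np.argsort(saliency)[::-1][:3])
```

In a CNN the mapping from input to score is non-linear, so the gradient is obtained via backpropagation rather than read off the weights, but the interpretation of the resulting per-pixel values is the same.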

FlashTorch demo 1

Visualising saliency maps with backpropagation


Install FlashTorch & load an image



$ pip install flashtorch

...
In [2]:
import matplotlib.pyplot as plt

from flashtorch.utils import load_image

image = load_image('../../examples/images/great_grey_owl.jpg')

plt.imshow(image)
plt.title('Original image')
plt.axis('off');

Apply transformations


In [3]:
from flashtorch.utils import apply_transforms, denormalize, format_for_plotting

input_ = apply_transforms(image)

print(f'Before: {type(image)}')
print(f'After: {type(input_)}, {input_.shape}')

plt.imshow(format_for_plotting(denormalize(input_)))
plt.title('Input tensor')
plt.axis('off');
Before: <class 'PIL.Image.Image'>
After: <class 'torch.Tensor'>, torch.Size([1, 3, 224, 224])

Create a Backprop object with a pre-trained model


In [4]:
from torchvision import models

from flashtorch.saliency import Backprop

model = models.alexnet(pretrained=True)

backprop = Backprop(model)
Signature:

    backprop.calculate_gradients(input_, target_class=None, ...)

Calculate the gradients of target class w.r.t. input


In [5]:
from flashtorch.utils import ImageNetIndex 

imagenet = ImageNetIndex()
target_class = imagenet['great grey owl']

print(f'Target class index: {target_class}')

gradients = backprop.calculate_gradients(input_, target_class)

max_gradients = backprop.calculate_gradients(input_, target_class, take_max=True)

print(type(gradients), gradients.shape)
print(type(max_gradients), max_gradients.shape)
Target class index: 24
<class 'torch.Tensor'> torch.Size([3, 224, 224])
<class 'torch.Tensor'> torch.Size([1, 224, 224])

Let's visualise gradients


In [6]:
from flashtorch.utils import visualize

visualize(input_, gradients, max_gradients)
Pixels where the animal is present have the strongest positive effects.
But it's quite noisy...

FlashTorch demo 2

Visualising saliency maps with guided backpropagation


Guided backpropagation


  • Additional guidance from higher layers during backprop

  • Masks out neurons that had no effect or a negative effect on the prediction

  • Preventing the flow of such gradients gives less noise
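
The masking rule at a ReLU can be sketched in a few lines (a NumPy illustration of the rule, not FlashTorch's hook-based implementation): during the backward pass, gradients are zeroed both where the forward activation was non-positive (the neuron had no effect) and where the incoming gradient is negative (a negative effect on the prediction):

```python
import numpy as np

def guided_relu_backward(forward_input, grad_output):
    """Guided-backprop rule at a ReLU.

    Standard ReLU backward masks only where forward_input <= 0.
    Guided backprop additionally masks where grad_output <= 0,
    letting only positive "evidence" flow back to the input.
    """
    return grad_output * (forward_input > 0) * (grad_output > 0)

forward_input = np.array([1.5, -0.5, 2.0, -1.0])   # activations seen in the forward pass
grad_output   = np.array([0.8,  0.3, -0.4,  0.6])  # gradients arriving from higher layers

print(guided_relu_backward(forward_input, grad_output))
# Only the first entry survives: positive activation AND positive gradient
```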

Calculate the gradients with guided backprop


In [7]:
guided_gradients = backprop.calculate_gradients(input_, target_class, guided=True)

max_guided_gradients = backprop.calculate_gradients(input_, target_class, take_max=True, guided=True)

visualize(input_, guided_gradients, max_guided_gradients)
Now that's much less noisy!
Pixels around the head and eyes have the strongest positive effects.

What about a peacock...


In [9]:
# Same pipeline as above, repeated with a peacock image and its target class
visualize(input_, guided_gradients, max_guided_gradients)

... or a toucan?


In [11]:
# Same pipeline again, with a toucan image and its target class
visualize(input_, guided_gradients, max_guided_gradients)

FlashTorch demo 3

Gaining additional insights on transfer learning


Transfer learning


  • A model developed for one task is reused as a starting point for another task

  • Often used in computer vision & natural language processing tasks

  • Saves compute & time resources

Building a flower classifier


From: DenseNet model, pre-trained on ImageNet (1000 classes)

To: Flower classifier recognising 102 species of flowers (dataset).

Pre-trained model gets it very wrong: test accuracy 0.1%


In [14]:
backprop = Backprop(pretrained_model)

guided_gradients = backprop.calculate_gradients(input_, class_index, guided=True)
guided_max_gradients = backprop.calculate_gradients(input_, class_index, take_max=True, guided=True)

visualize(input_, guided_gradients, guided_max_gradients)
/Users/misao/Projects/personal/flashtorch/flashtorch/saliency/backprop.py:94: UserWarning: The predicted class index 1 does not equal the target class index 96. Calculating the gradient w.r.t. the predicted class.
  'the gradient w.r.t. the predicted class.'

Trained model - test accuracy 98.7%


In [15]:
backprop = Backprop(trained_model)

guided_gradients = backprop.calculate_gradients(input_, class_index, guided=True)
guided_max_gradients = backprop.calculate_gradients(input_, class_index, take_max=True, guided=True)

visualize(input_, guided_gradients, guided_max_gradients)
The trained model pays attention specifically to the distinguishing pattern of this particular species.
In line with what we would focus on!

Thank you!